Skip to content

[MLA] Fix nhead=32 non-persistent decode crash on gfx950#2983

Merged
valarLip merged 2 commits intoROCm:mainfrom
frida-andersson:fix/mla-nhead32-nonpersistent-crash
May 6, 2026
Merged

[MLA] Fix nhead=32 non-persistent decode crash on gfx950#2983
valarLip merged 2 commits intoROCm:mainfrom
frida-andersson:fix/mla-nhead32-nonpersistent-crash

Conversation

@frida-andersson
Copy link
Copy Markdown
Contributor

Summary

Fixes a GPU memory access fault when running MLA decode with nhead=32 (DeepSeek-V3.2 at TP4) in non-persistent mode on MI355X (gfx950).

Root Cause

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)" #2729) zeroed ptr_RP and out_16_nosplit for all non-persistent dispatch. However, the legacy QH16 ASM kernel (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) used for nhead=32 still writes directly to the output buffer via ptr_RP when kv_split==1. Dereferencing nullptr causes:

Memory access fault by GPU node-X on address 0xNNNNNN. Reason: Write access to a read-only page.

This crashes during CUDA graph capture (decode, FULL).

Fix

C++ (csrc/py_itfs_cu/asm_mla.cu):

Python (aiter/mla.py):

  • Restore the bf16 nhead in [32, 64] early-return after stage1 when num_kv_splits==1. Without this, stage2 overwrites the kernel's direct output with garbage from the uninitialized split buffer.

Both changes match the behavior from v0.1.11 for the affected code paths.

Test

  • MI355X (gfx950), TP4, DeepSeek-V3.2
  • No crash during CUDA graph capture
  • GSM8K accuracy correct (0.94+)

@frida-andersson frida-andersson requested a review from a team April 30, 2026 14:38
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2983 --add-label <label>

Comment thread csrc/py_itfs_cu/asm_mla.cu Outdated
args.out_16_nosplit = 0;
args.ptr_RP = nullptr;
// Legacy QH16 ASM kernels (nhead=32/64, qseqlen=1) write directly to
// output via ptr_RP when kv_split==1. Passing nullptr causes GPU
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It really sounds to me like we are missing a test for this scenario in aiter. There should be a more transparent way of distinguishing kernels that write these but by the number of QHs.

@ChuanLi1101
Copy link
Copy Markdown
Contributor

ChuanLi1101 commented May 3, 2026

Independently verified on MI355X (gfx950) inside rocm/atom-dev:vllm-v0.19.0-nightly_20260422 (ROCm 7.2.2, torch 2.10.0+rocm7.2.2.git40d237bf, aiter HEAD 9522c04).

Repro path — exactly the failure mode this PR describes:

  • nhead=32, bf16/bf16, num_kv_splits=1, decode_qlen=1, non-persistent (no work_meta_data)
  • BEFORE the patch: Memory access fault by GPU node-2 ... address (nil) on the first call into mla_a16w16_qh16_m32x1_n16x1_coex0_mask1. Process aborts (exit 134).
  • AFTER the patch (after rebuilding module_mla_asm so the asm_mla.cu change is actually picked up): 12/12 cases pass for nhead in {32, 64} x ctx in {256, 1024} x bs in {1, 4, 16}, all numerically match the torch reference at atol=rtol=1e-2, latency 13-67 us. Same kernel as before, just with a non-null ptr_RP.
pr2983 nhead=32 ctx=256  bs=1  splits=1:    21.02 us  passed
pr2983 nhead=32 ctx=256  bs=4  splits=1:    21.82 us  passed
pr2983 nhead=32 ctx=256  bs=16 splits=1:    22.81 us  passed
pr2983 nhead=32 ctx=1024 bs=1  splits=1:    65.40 us  passed
pr2983 nhead=32 ctx=1024 bs=4  splits=1:    65.70 us  passed
pr2983 nhead=32 ctx=1024 bs=16 splits=1:    67.53 us  passed
pr2983 nhead=64 ctx=256  bs=1  splits=1:    13.28 us  passed
pr2983 nhead=64 ctx=256  bs=4  splits=1:    13.95 us  passed
pr2983 nhead=64 ctx=256  bs=16 splits=1:    15.11 us  passed
pr2983 nhead=64 ctx=1024 bs=1  splits=1:    35.15 us  passed
pr2983 nhead=64 ctx=1024 bs=4  splits=1:    36.00 us  passed
pr2983 nhead=64 ctx=1024 bs=16 splits=1:    38.28 us  passed

The diff is narrowly scoped: 19 lines / 2 files, only the non-persistent host wrapper (asm_mla.cu) and the matching nhead in [32, 64] branch in the python entrypoint. No kernel change, no impact on the persistent / sparse / unified-attention paths.

Could a maintainer apply the ready label so review CI can run? @sunway513 — given the v0.1.13 / vLLM 0.21 freeze, this looks like a low-risk addition to the #3005 bulk merge.

Note for anyone reproducing locally: AITER_REBUILD=1 alone won't pick up the asm_mla.cu change — mla_decode_stage1_asm_fwd uses the ctypes ffi path which only checks for the .so's existence. To force a rebuild after applying the patch:

rm -f aiter/jit/module_mla_asm.so
rm -rf aiter/jit/build/module_mla_asm
python op_tests/test_mla_nhead32_regression.py

sunway513 added a commit that referenced this pull request May 3, 2026
Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).
@frida-andersson frida-andersson force-pushed the fix/mla-nhead32-nonpersistent-crash branch from 3fe1ecd to c07ad8d Compare May 4, 2026 10:10
@frida-andersson
Copy link
Copy Markdown
Contributor Author

v2 update — rebased on latest main, narrowed the fix condition.

v1 → v2 change: gqa_ratio * max_seqlen_q <= 64 / nhead in [32, 64] incorrectly set ptr_RP for v3/stage1 kernels (gqa=8,16,64) that don't use it — this caused the test_mla precision failures @sunway513 saw. v2 narrows to the exact legacy kernel: gqa_ratio == 32 && max_seqlen_q == 1 / nhead == 32 and max_seqlen_q == 1.

Also narrows mgc and MAYBE_FINAL_OUT guards (were nhead in [8, 16] → now nhead == 16 only) and removes dead gqa_ratio == 8 non-persistent block.

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass. bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

@valarLip valarLip requested a review from fangche123 May 5, 2026 02:23
sunway513 added a commit that referenced this pull request May 5, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@azaidy
Copy link
Copy Markdown
Contributor

azaidy commented May 5, 2026

@frida-andersson I assume this is a short lived temporary fix. We need an architecturally sound fix proposed/merged within 2 weeks.

cc @maeehart @sunway513

@maeehart
Copy link
Copy Markdown
Contributor

maeehart commented May 5, 2026

Yes, @azaidy, our long term fix is in #2983. However, we are not yet ready for merging it since it appears to have a crashing issue.

The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16,
non-persistent, decode qseqlen=1) writes its output directly via
ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and
out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2
TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND
  max_seqlen_q==1 (the exact legacy kernel condition). Other
  non-persistent kernels (v3, stage1) use split-reduce and expect
  ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when
  nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass.
bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: ROCm#2999 (broken — tile mismatch, 85% wrong output)
Co-authored-by: Cursor <cursoragent@cursor.com>
@frida-andersson frida-andersson force-pushed the fix/mla-nhead32-nonpersistent-crash branch from 35f39b3 to b228670 Compare May 5, 2026 17:44
sunway513 added a commit that referenced this pull request May 5, 2026
Cherry-pick of b228670 from PR #2983 onto release/v0.1.13.
No conflicts.

Original PR: #2983
Copy link
Copy Markdown
Collaborator

@valarLip valarLip left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@valarLip valarLip merged commit e09effa into ROCm:main May 6, 2026
30 checks passed
@frida-andersson frida-andersson deleted the fix/mla-nhead32-nonpersistent-crash branch May 6, 2026 07:48
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
…3-Next, pa_mqa OOB) (#3005)

* fix: remap QuantType.No to per_1x32 for fp4x2 MoE weights (W4A6 support)

* Fixing two cascading bugs when running the MoE tuner

* Enable split-K for block-scale A8W8 CK and CKTile GEMMs

Propagate the splitK parameter (as KBatch = 2^splitK) through the
block-scale GEMM kernel infrastructure so that the tuning scripts
can sweep split-K values to improve occupancy on small-M shapes.

CK path: add KBatch parameter to gemm_a8w8_blockscale_impl and call
SetKBatch on the device argument. The CK invoker handles output
zeroing and atomic accumulation internally.

CKTile path: add k_batch parameter to gemm_a8w8_blockscale_cktile_impl,
remove the "split-k is not supported yet" runtime guard, and add
hipMemsetAsync to zero the output buffer before atomic accumulation.

Non-tune entry points pass KBatch=1 (no split-K) to preserve existing
behavior. Code generation scripts (gen_instances.py, gen_instances_cktile.py)
updated to include the new parameter in generated wrappers and manifests.

Made-with: Cursor

* Wire splitK from tuning CSV through production blockscale GEMM dispatch

The tuning infrastructure already sweeps splitK and writes it to the CSV,
but the production dispatch ignored it and hardcoded KBatch=1. Add splitK
as a runtime parameter to the non-tune entry points so tuned split-K
values are used without compiling the full _tune instance set.

Made-with: Cursor

* fix: ck_moe_stage1 split-K output buffer overflow from padding scatter

The CK kernel scatters output via sorted_token_ids using:
  token_offset = (fused_token & 0xffffff) * topk + (fused_token >> 24)

Padding entries use the sentinel value (topk << 24 | token_num),
which decodes to scatter position (token_num * topk + topk) -- beyond
the valid output range [0, token_num * topk). The original buffer
(token_num, topk, w1.shape[1]) only has token_num * topk rows, so
the padding scatter writes out of bounds, causing "HIP runtime error:
invalid argument" during CUDA graph capture (e.g. DeepSeek-R1 decode
with token_num=1, topk=8, block_m=16).

Fix: allocate (token_num * topk + topk + 1) rows -- the exact minimum
needed to absorb all padding scatter writes. After the kernel, slice
only the valid [0, token_num * topk) rows for the activation.

Related: #2508
Made-with: Cursor

* Address PR review feedback: validate splitK, fix hipMemset stride issue, add correctness test

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e3b37b0f-e151-4935-ad89-fd72436d41e2

Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>

* black format

* fix splitk test dimensions

* Add gdn fusions

* style: fix ruff F841 and black-format Triton PR files

Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor

* Update fused_rearrange_sigmoid_gdr.py

* Update op_tests

* Fix BLACK format problem

* Fix black check failure

* Update test_fused_rearrange_sigmoid_gdr.py

* Allow callers to pass pre-allocated moe_buf to avoid output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes
directly into the caller's buffer instead of allocating a new one,
eliminating a redundant copy on the output path.

Made-with: Cursor

* Add moe_buf pass-through test to existing test_moe_sorting

Made-with: Cursor

* Replace _fast with _single_token for causal conv1d update kernels for single token decoding

* Fix blck format error

* Add tuned a8w8 blockscale GEMM config for Qwen3-Next-80B-A3B on MI355X

Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR #2862 (CK bump for stride fix in CK-TILE blockscale)
- PR #2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR #2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)

* refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor

* Retune blockscale GEMM configs to fix invalid kernelId+splitK combinations

Full retune of all 1482 shapes on MI355X (gfx950, cu_num=256).
Key changes:
- SplitK usage dropped from 613 to 88 CK shapes (splitK > 0)
- All shapes validated via --run_config (1482/1482 OK)
- E2e perf: 2-8% output throughput improvement vs untuned heuristic

* [Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer

The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
is the same SMEM descriptor path as the unmasked variant, just with a
VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>

* style: fix Black formatting

* style: fix Black formatting (Python 3.12 compatible)

* ci: replace deprecated zmq package with pyzmq

The `zmq` meta-package fails to install on some CI runners because
it cannot resolve the `pyzmq` dependency. Use `pyzmq` directly,
which is the actual package providing ZeroMQ bindings for Python.

Fixes Triton Test Shard 7 setup failures.

* ci: increase pip retries and timeout for CI reliability

Set pip global retries=15 and timeout=120s in build_aiter_triton.sh
to handle transient PyPI network failures on self-hosted runners.
Shard 5/7 failures were caused by RemoteDisconnected during pip install.

* ci: make pyzmq install non-blocking in triton test setup

pyzmq is only used by aiter.dist.shm_broadcast, not by any triton
test. When PyPI is unreachable on self-hosted runners, the pyzmq
install failure should not block the entire CI shard.

Split pyzmq into a separate pip install with || fallback so triton
tests can proceed even when PyPI connectivity is degraded.

* ci: retry pip install individually on batch failure

When batch pip install fails (e.g., PyPI connectivity issues on
self-hosted runners), retry each package individually. Only pyzmq
is allowed to fail silently since it's only used by
aiter.dist.shm_broadcast and not required by any CI test suite.

Critical packages (pandas, einops, numpy) must still succeed.

* [MLA] Fix nhead=32 non-persistent decode crash on gfx950

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor

* revert: remove #2983 (MLA nhead=32 fix) — causes test_mla CI failures

Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).

* fix: restore tuple unpack for FlyDSL fused-quant stage1 return

flydsl_moe_stage1 returns (out, out_scale_sorted) when the kernel uses
fused fp4/fp8 quantization. The tuple unpack logic was removed during
earlier refactoring but the kernel behavior was not changed, causing
fused_moe_2stages to crash with:
  AttributeError: 'tuple' object has no attribute 'view'

Restore the unpack: detect tuple return, extract tensor and scale,
handle fp4 byte-packing trim, and skip redundant Python-side requant
when the kernel already produced sorted scales.

* Revert leaked changes from excluded PRs #2457/#2547/#2687 in fused_moe.py

- Restore import to match main: use `from aiter import
  fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd` instead of
  importing from internal triton path and fp4_utils
- Replace all fp4_utils.moe_mxfp4_sort() calls with mxfp4_moe_sort_fwd()
  using correct parameter names (cols= instead of block_size=)
- Remove all moe_buf preallocated buffer additions (PR #2687 rejected):
  parameter defaults, if-guards, and pass-throughs in _moe_sorting_impl,
  moe_sorting, fused_moe, fused_moe_fake, and fused_moe_
- Fix moe_sorting_dispatch_policy type annotation: bool -> int in
  fused_moe_fake and fused_moe_
- Remove moe_buf pass-through test from test_moe_sorting.py
- Preserve legitimate fp4_utils usage (mxfp4_to_f32, e8m0_to_f32) with
  local imports in stage1/stage2 fallback functions

* fix: restore fp4_utils.moe_mxfp4_sort for new code paths (different output layout than mxfp4_moe_sort_fwd)

* style: fix Black formatting for local imports

* fix: remove rejected W4A6 QuantType remap from fused_moe_dp_shared_expert

Lingpeng explicitly rejected this change (from excluded PR #2457).
Reverts the QuantType.No -> per_1x32 remap for fp4x2 weights.

* fix: restore silently-reverted main features from bad merge resolution

aiter/fused_moe.py:
- Restore to origin/main. Per sunway513's own comment, #2457 and #2547
  were excluded from this bulk merge; per valarLip, #2687 was rejected.
  No source PR should land changes in this file. The previous state
  (+110/-119 vs main) was collateral damage from auto-resolved conflicts
  taking older sides, which silently reverted #2262 (xbf16 asm fmoe path),
  #2726 (FlyDSL a8w4 MoE wrapper params + fuse_quant), #2658 (CK fp8
  blockscale splitk tuner support), and #2620 (mxfp4_moe_sort_hip,
  flagged by valarLip).

op_tests/test_gemm_a8w8_blockscale.py:
- Replace with a clean 3-way merge of origin/main + #2541. Now +55/-0
  vs main, matching #2541's actual contribution exactly. The previous
  state was silently reverting #2645 (CK GEMM multi-arch + test infra:
  TEST_NUM_ITERS, --csv/--output args, kernel_name= param).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: remove #2464 from bulk merge per author request

@xaguilar-amd asked to drop #2464 (CK MoE tuner bug fixes) from this
bulk merge — they don't need it for the uplift.

Verified that #2464 is the only PR in this bulk merge touching
aiter/jit/core.py and aiter/utility/mp_tuner.py: the diff between the
branch and origin/main on those files is exactly #2464's +9/-1 and
+5/-0, with no other PR content mixed in. Restoring both files to
origin/main therefore drops #2464 cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: vecheruk-amd <vecheruk@amd.com>
Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com>
Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Li <chuali@amd.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: samremes <181322991+samremes@users.noreply.github.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com>
Co-authored-by: Niklas Holmberg <nholmber@users.noreply.github.com>
Co-authored-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: frida-andersson <fanderss@amd.com>
Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Liang-jianhao97 pushed a commit that referenced this pull request May 7, 2026
The legacy QH16 m32x1_n16x1 ASM kernel (gqa_ratio=32, bf16/bf16,
non-persistent, decode qseqlen=1) writes its output directly via
ptr_RP when kv_split==1. Upstream passes ptr_RP=nullptr and
out_16_nosplit=0, causing GPU memory faults on gfx950 (DeepSeek-V3.2
TP4 hits this with nhead=32).

Fix:
- C++: set ptr_RP and out_16_nosplit only when gqa_ratio==32 AND
  max_seqlen_q==1 (the exact legacy kernel condition). Other
  non-persistent kernels (v3, stage1) use split-reduce and expect
  ptr_RP = nullptr, so they are unaffected.
- Python: reuse output buffer for logits and skip stage2 only when
  nhead==32 and max_seqlen_q==1 (matches the C++ gate).

Tested on MI355X (gfx950): nhead=8,16,32,64,128 all pass.
bf16/bf16, ctx_lens=[256,1024], batch=[1,4,16].

Supersedes: #2999 (broken — tile mismatch, 85% wrong output)

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: azaidy <aliasger.zaidy@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants